{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, cv2\n",
    "import string\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from vlmeval.smp import ls, load, dump, download_file, encode_image_file_to_base64, md5, mrlines\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import multiprocessing as mp\n",
    "from PIL import Image, ImageFont, ImageDraw\n",
    "\n",
    "font_URL = 'http://opencompass.openxlab.space/utils/Fonts/timesb.ttf'\n",
    "font_file = 'timesb.ttf'\n",
    "if not osp.exists(font_file):\n",
    "    download_file(font_URL)\n",
    "    \n",
    "test_split_URL = 'https://s3-us-east-2.amazonaws.com/prior-datasets/ai2d_test_ids.csv'\n",
    "test_split_file = 'ai2d_test_ids.csv'\n",
    "if not osp.exists(test_split_file):\n",
    "    download_file(test_split_URL)\n",
    "    \n",
    "test_ids = set(mrlines(test_split_file))\n",
    "    \n",
    "def proper_font_size(font_file, wh, text, ratio=1):\n",
    "    font_size = 2\n",
    "    while True:\n",
    "        font = ImageFont.truetype(font_file, font_size)\n",
    "        real_box = font.getbbox(text)\n",
    "        real_wh = (real_box[2] - real_box[0], real_box[3] - real_box[1])\n",
    "        if real_wh[0] > wh[0] * ratio or real_wh[1] > wh[1] * ratio:\n",
    "            break\n",
    "        font_size += 1\n",
    "    return font_size\n",
    "\n",
    "def cover_image(ann_path):\n",
    "    data = load(ann_path)\n",
    "    texts = list(data['text'].values())\n",
    "    raw_img = ann_path.replace('annotations', 'images').replace('.json', '')\n",
    "    tgt_img = raw_img.replace('images', 'images_abc')\n",
    "    img = Image.open(raw_img)\n",
    "    draw = ImageDraw.Draw(img)\n",
    "    for text in texts:\n",
    "        st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
    "        T = text['replacementText']\n",
    "        draw.rectangle((st, ed), fill='white')\n",
    "        font_size = proper_font_size(font_file, (ed[0] - st[0], ed[1] - st[1]), T, ratio=1)\n",
    "        font = ImageFont.truetype(font_file, font_size)\n",
    "        text_box = font.getbbox(T)\n",
    "        text_wh = (text_box[2] - text_box[0], text_box[3] - text_box[1])\n",
    "        cx, cy = (st[0] + ed[0]) // 2, st[1]\n",
    "        stx = cx - text_wh[0] // 2\n",
    "        sty = cy - text_wh[1] // 2\n",
    "        draw.text((stx, sty), T, font=font, fill='black')\n",
    "    img.save(tgt_img)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process for no mask images\n",
    "test_ids = set(mrlines(test_split_file))\n",
    "\n",
    "def detect_image_color(image):\n",
    "    gray_image = image.convert('L')\n",
    "    mean_brightness = np.mean(np.array(gray_image))\n",
    "    if mean_brightness < 127:\n",
    "        return 'white'\n",
    "    else:\n",
    "        return 'black'\n",
    "\n",
    "def cover_image(ann_path):\n",
    "    data = load(ann_path)\n",
    "    texts = list(data['text'].values())\n",
    "    raw_img = ann_path.replace('annotations', 'images').replace('.json', '')\n",
    "    tgt_img = raw_img.replace('images', 'images_abc')\n",
    "    img = Image.open(raw_img)\n",
    "    draw = ImageDraw.Draw(img)\n",
    "    color = detect_image_color(img)\n",
    "    font_size = 0\n",
    "    for text in texts:\n",
    "        st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
    "        font_size += (ed[1] - st[1])\n",
    "    if len(texts) != 0:\n",
    "        font_size /= len(texts)\n",
    "    else:\n",
    "        font_size = 2\n",
    "    for text in texts:\n",
    "        st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
    "        T = text['replacementText']\n",
    "        for i in range(2):\n",
    "            draw.rectangle(\n",
    "                [(st[0] - i, st[1] - i), (ed[0] + i, ed[1] + i)],\n",
    "                outline=color\n",
    "            )\n",
    "        font = ImageFont.truetype(font_file, font_size)\n",
    "        text_box = font.getbbox(T)\n",
    "        text_wh = (text_box[2] - text_box[0], text_box[3] - text_box[1])\n",
    "        cx, cy = (st[0] + ed[0]) // 2, st[1]\n",
    "        stx = cx - text_wh[0] // 2\n",
    "        sty = cy - text_wh[1] * 1.5\n",
    "        if sty < 0:\n",
    "            sty = cy + text_wh[1] * 1.3\n",
    "        draw.text((stx, sty), T, font=font, fill=color)\n",
    "    img.save(tgt_img)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "download_file('https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip')\n",
    "os.system('unzip -o ai2d-all.zip')\n",
    "\n",
    "images = ls('ai2d/images/')\n",
    "questions = ls('ai2d/questions/')\n",
    "annotations = ls('ai2d/annotations/')\n",
    "cates = load('ai2d/categories.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pool = mp.Pool(32)\n",
    "pool.map(cover_image, annotations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def puncproc(inText):\n",
    "    import re\n",
    "    outText = inText\n",
    "    punct = [\n",
    "        ';', r'/', '[', ']', '\"', '{', '}', '(', ')', '=', '+', '\\\\', '_', '-',\n",
    "        '>', '<', '@', '`', ',', '?', '!'\n",
    "    ]\n",
    "    commaStrip = re.compile('(\\d)(,)(\\d)')  # noqa: W605\n",
    "    periodStrip = re.compile('(?!<=\\d)(\\.)(?!\\d)')  # noqa: W605\n",
    "    for p in punct:\n",
    "        if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) is not None):\n",
    "            outText = outText.replace(p, '')\n",
    "        else:\n",
    "            outText = outText.replace(p, ' ')\n",
    "    outText = periodStrip.sub('', outText, re.UNICODE)\n",
    "    return outText\n",
    "\n",
    "def check_choices(line):\n",
    "    def ischar(s):\n",
    "        s = str(s)\n",
    "        if s in ['{}', 'Both', 'None of above']:\n",
    "            return True\n",
    "        elif s.startswith('Stage ') and ischar(s[6:]):\n",
    "            return True\n",
    "        elif ' and ' in s and np.all([ischar(x) for x in s.split(' and ')]):\n",
    "            return True\n",
    "        elif len(s) <= 2:\n",
    "            return True\n",
    "        elif len(puncproc(s).split()) > 1:\n",
    "            return np.all([ischar(x) for x in puncproc(s).split()])\n",
    "        return False\n",
    "    n_char = sum([ischar(line[x]) for x in 'ABCD'])\n",
    "    return n_char >= 3\n",
    "\n",
    "def check_question(question):\n",
    "    words = puncproc(question).split()\n",
    "    for ch in string.ascii_lowercase + string.ascii_uppercase:\n",
    "        if ch in words:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "def is_abc(abc, choices, question):\n",
    "    if abc == 0:\n",
    "        return False\n",
    "    if check_choices(choices):\n",
    "        return True\n",
    "    if check_question(question):\n",
    "        return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_all = defaultdict(list)\n",
    "for qfile in questions:\n",
    "    data = load(qfile)\n",
    "    idx = data['imageName'].split('.')[0]\n",
    "    if idx not in test_ids:\n",
    "        continue\n",
    "    image_pth = qfile.replace('questions', 'images').replace('.json', '')\n",
    "    cate = cates[image_pth.split('/')[-1]]\n",
    "    for q, qmeta in data['questions'].items():\n",
    "        assert '.png-' in qmeta['questionId']\n",
    "        main, sub = qmeta['questionId'].split('.png-')\n",
    "        idx = int(main) * 100 + int(sub)\n",
    "        \n",
    "        answers = qmeta['answerTexts']\n",
    "        correct = qmeta['correctAnswer']\n",
    "        \n",
    "        data_all['index'].append(idx)\n",
    "        data_all['question'].append(q)\n",
    "        assert len(answers) == 4\n",
    "        for c, a in zip('ABCD', answers):\n",
    "            data_all[c].append(a)\n",
    "        data_all['answer'].append('ABCD'[qmeta['correctAnswer']])\n",
    "        data_all['category'].append(cate)\n",
    "        data_all['abcLabel'].append(qmeta['abcLabel'])\n",
    "        abc = is_abc(qmeta['abcLabel'], {x: data_all[x][-1] for x in 'ABCD'}, q)\n",
    "        # if qmeta['abcLabel'] and not abc:\n",
    "        #     print(qmeta['abcLabel'], {x: data_all[x][-1] for x in 'ABCD'}, q)\n",
    "        data_all['image_path'].append(image_pth.replace('images', 'images_abc') if abc else image_pth)\n",
    "data = pd.DataFrame(data_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = []\n",
    "image_seen = {}\n",
    "for idx, pth in zip(data['index'], data['image_path']):\n",
    "    images.append(encode_image_file_to_base64(pth))\n",
    "\n",
    "data['image'] = images\n",
    "dump(data, 'AI2D_TEST.tsv')\n",
    "print(md5('AI2D_TEST.tsv'))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
